package com.gmcloud.common.security.config;

import com.gmcloud.common.security.config.oauth2.PermitAllUrlProperties;
import com.google.common.collect.Sets;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.server.resource.BearerTokenError;
import org.springframework.security.oauth2.server.resource.BearerTokenErrors;
import org.springframework.security.oauth2.server.resource.web.BearerTokenResolver;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import org.springframework.util.StringUtils;

import javax.servlet.http.HttpServletRequest;
import java.util.Objects;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * @author zl.sir
 * @version 1.0
 * @since 2022/8/17 12:14
 * 请求令牌的抽取逻辑
 */
public class GmBearerTokenExtractor implements BearerTokenResolver {

    private static final Logger LOGGER = LoggerFactory.getLogger(GmBearerTokenExtractor.class);

    private static final Pattern authorizationPattern = Pattern.compile("^Bearer (?<token>[a-zA-Z0-9-:._~+/]+=*)$",
            Pattern.CASE_INSENSITIVE);
    private String bearerTokenHeaderName = HttpHeaders.AUTHORIZATION;

    private final PermitAllUrlProperties urlProperties;

    private boolean allowFormEncodedBodyParameter = false;

    private boolean allowUriQueryParameter = false;


    public GmBearerTokenExtractor(PermitAllUrlProperties urlProperties) {
        this.urlProperties = urlProperties;
    }


    /**
     * 请求中解析令牌,获取token
     *
     * @param request 请求
     * @return string
     */
    @Override
    public String resolve(HttpServletRequest request) {
        boolean match = checkIgnores(request);

        if (match) {
            return null;
        }
        // 验证header的token
        final String authorizationHeaderToken = resolveFromAuthorizationHeader(request);
        // 验证param的token
        final String parameterToken = Boolean.TRUE.equals(isParameterTokenSupportedForRequest(request))
                ? resolveFromRequestParameters(request) : null;

        if (authorizationHeaderToken != null) {
            if (parameterToken != null) {
                final BearerTokenError error = BearerTokenErrors
                        .invalidRequest("Found multiple bearer tokens in the request");
                throw new OAuth2AuthenticationException(error);
            }
            return authorizationHeaderToken;
        }
        if (parameterToken != null && isParameterTokenEnabledForRequest(request)) {
            return parameterToken;
        }
        return null;
    }

    /**
     * 验证token是否存在参数param里面的
     * @param request 请求
     * @return boolean
     */
    private Boolean isParameterTokenSupportedForRequest(HttpServletRequest request) {
        return (("POST".equals(request.getMethod())
                && MediaType.APPLICATION_FORM_URLENCODED_VALUE.equals(request.getContentType()))
                || "GET".equals(request.getMethod()));
    }

    private  String resolveFromRequestParameters(HttpServletRequest request) {
        String[] values = request.getParameterValues("access_token");
        if (values == null || values.length == 0) {
            return null;
        }
        if (values.length == 1) {
            return values[0];
        }
        BearerTokenError error = BearerTokenErrors.invalidRequest("Found multiple bearer tokens in the request");
        throw new OAuth2AuthenticationException(error);
    }

    /**
     * 验证header里的token
     *
     * @param request 请求1
     * @return string
     */
    private String resolveFromAuthorizationHeader(HttpServletRequest request) {
        String authorization = request.getHeader(this.bearerTokenHeaderName);

        if (!StringUtils.startsWithIgnoreCase(authorization, "bearer")) {
            return null;
        }

        Matcher matcher = authorizationPattern.matcher(authorization);

        if (!matcher.matches()) {
            BearerTokenError error = BearerTokenErrors.invalidToken("Bearer token is malformed");
            throw new OAuth2AuthenticationException(error);
        }
        return matcher.group("token");
    }

    private boolean isParameterTokenEnabledForRequest(final HttpServletRequest request) {
        return ((this.allowFormEncodedBodyParameter && "POST".equals(request.getMethod())
                && MediaType.APPLICATION_FORM_URLENCODED_VALUE.equals(request.getContentType()))
                || (this.allowUriQueryParameter && "GET".equals(request.getMethod())));
    }

    /**
     * 请求是否不需要进行权限拦截
     *
     * @param request 当前请求
     * @return true - 忽略，false - 不忽略
     */
    private boolean checkIgnores(HttpServletRequest request) {
        String method = request.getMethod();

        Set<String> ignores = Sets.newHashSet();

        // 这是配置文件读取白名单，留用
        HttpMethod httpMethod = HttpMethod.resolve(method);

        if (Objects.isNull(httpMethod)) {
            httpMethod = HttpMethod.GET;
        }

        if (Objects.nonNull(urlProperties.getIgnore())) {
            switch (httpMethod) {
                case GET:
                    ignores.addAll(urlProperties.getIgnore().getGet());
                    break;
                case PUT:
                    ignores.addAll(urlProperties.getIgnore().getPut());
                    break;
                case HEAD:
                    ignores.addAll(urlProperties.getIgnore().getHead());
                    break;
                case POST:
                    ignores.addAll(urlProperties.getIgnore().getPost());
                    break;
                case PATCH:
                    ignores.addAll(urlProperties.getIgnore().getPatch());
                    break;
                case TRACE:
                    ignores.addAll(urlProperties.getIgnore().getTrace());
                    break;
                case DELETE:
                    ignores.addAll(urlProperties.getIgnore().getDelete());
                    break;
                case OPTIONS:
                    ignores.addAll(urlProperties.getIgnore().getOptions());
                    break;
                default:
                    LOGGER.info("没有白名单配置");
                    break;
            }

            ignores.addAll(urlProperties.getIgnore().getPattern());

        }
        if (!ignores.isEmpty()) {
            for (String ignore : ignores) {
                AntPathRequestMatcher matcher = new AntPathRequestMatcher(ignore, method);

                if (matcher.matches(request)) {
                    return true;
                }
            }
        }
        return false;
    }
}
